{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"  # number GPUs by PCI bus ID\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # make only GPU 0 visible"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text Classification with Hugging Face Transformers in *ktrain*\n",
    "\n",
    "As of v0.8.x, *ktrain* includes an easy-to-use, thin wrapper around the Hugging Face transformers library for text classification."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data Into Arrays"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "size of training set: 2257\n",
      "size of validation set: 1502\n",
      "classes: ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n"
     ]
    }
   ],
   "source": [
    "categories = ['alt.atheism', 'soc.religion.christian',\n",
    "             'comp.graphics', 'sci.med']\n",
    "from sklearn.datasets import fetch_20newsgroups\n",
    "train_b = fetch_20newsgroups(subset='train',\n",
    "   categories=categories, shuffle=True, random_state=42)\n",
    "test_b = fetch_20newsgroups(subset='test',\n",
    "   categories=categories, shuffle=True, random_state=42)\n",
    "\n",
    "print('size of training set: %s' % (len(train_b['data'])))\n",
    "print('size of validation set: %s' % (len(test_b['data'])))\n",
    "print('classes: %s' % (train_b.target_names))\n",
    "\n",
    "x_train = train_b.data\n",
    "y_train = train_b.target\n",
    "x_test = test_b.data\n",
    "y_test = test_b.target"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 1: Preprocess Data and Build a Transformer Model\n",
    "\n",
    "For `MODEL_NAME`, *ktrain* supports both the \"official\" built-in models [available here](https://huggingface.co/transformers/pretrained_models.html) and the [community-uploaded models available here](https://huggingface.co/models)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using Keras version: 2.2.4-tf\n",
      "preprocessing train...\n",
      "language: en\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocessing test...\n",
      "language: en\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import ktrain\n",
    "from ktrain import text\n",
    "MODEL_NAME = 'distilbert-base-uncased'\n",
    "t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)\n",
    "trn = t.preprocess_train(x_train, y_train)\n",
    "val = t.preprocess_test(x_test, y_test)\n",
    "model = t.get_classifier()\n",
    "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `x_train` and `x_test` are lists of raw texts that look like this:\n",
    "```python\n",
    "x_train = ['I hate this movie.', 'I like this movie.']\n",
    "```\n",
    "The labels are arrays in one of the following forms:\n",
    "```python\n",
    "# string labels\n",
    "y_train = ['negative', 'positive']\n",
    "# integer labels\n",
    "y_train = [0, 1]\n",
    "# multi-hot or one-hot encoded labels\n",
    "y_train = [[1,0], [0,1]]\n",
    "```\n",
    "In the latter two cases, you must supply a `class_names` argument to the `Transformer` constructor, which tells *ktrain* how indices map to class names. In this notebook, `class_names=['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']` (passed as `train_b.target_names`) because 0=alt.atheism, 1=comp.graphics, etc."
   ]
  },
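  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three label formats carry the same information once the index order is fixed. As a sketch in plain NumPy (independent of *ktrain*, with a made-up `y_str` list), the snippet below converts string labels to the integer and one-hot forms and back, using a `class_names` list as the index-to-name mapping. Sorting the names alphabetically matches the order of `train_b.target_names` printed above.\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical string labels, as in the toy example above\n",
    "y_str = ['negative', 'positive', 'positive', 'negative']\n",
    "\n",
    "# class_names fixes the index order: index i <-> class_names[i]\n",
    "class_names = sorted(set(y_str))\n",
    "name_to_idx = {name: i for i, name in enumerate(class_names)}\n",
    "\n",
    "# Integer labels\n",
    "y_int = np.array([name_to_idx[name] for name in y_str])\n",
    "\n",
    "# One-hot labels: row i has a 1 in column y_int[i]\n",
    "y_onehot = np.eye(len(class_names), dtype=int)[y_int]\n",
    "\n",
    "# Recover the string labels from the one-hot encoding via class_names\n",
    "recovered = [class_names[i] for i in y_onehot.argmax(axis=1)]\n",
    "print(y_int.tolist(), y_onehot.tolist(), recovered)\n",
    "```"
   ]
  },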
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 2 [Optional]: Estimate a Good Learning Rate\n",
    "\n",
    "Learning rates between `2e-5` and `5e-5` tend to work well with transformer models, as reported in Google's BERT paper. However, we will run our learning-rate finder for at most two epochs to estimate a good learning rate for this particular dataset.\n",
    "\n",
    "As shown below, our results are consistent with Google's findings."
   ]
  },
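  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A learning-rate finder follows the idea of Leslie Smith's LR range test: train briefly while increasing the learning rate exponentially, record the loss at each step, and choose a rate from the region where the loss is still dropping steeply, before it diverges. The sketch below illustrates only that idea with NumPy. The sweep bounds and the toy loss curve are made-up assumptions, and the steepest-slope rule is one common heuristic, not a description of the internals of *ktrain*'s `lr_find`.\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical sweep settings (not ktrain's actual defaults)\n",
    "start_lr, end_lr, n_steps = 1e-7, 10.0, 100\n",
    "\n",
    "# Exponentially increasing learning rates, one per mini-batch\n",
    "lrs = np.geomspace(start_lr, end_lr, n_steps)\n",
    "\n",
    "# Toy stand-in for the recorded losses: flat at first, dropping around 3e-5,\n",
    "# then diverging once the learning rate gets too large\n",
    "log_lr = np.log10(lrs)\n",
    "loss = 1.4 - 1.0 / (1.0 + np.exp(-3.0 * (log_lr + 4.5))) + np.exp(2.0 * (log_lr + 3.0))\n",
    "\n",
    "# One common heuristic: the learning rate where the loss falls most steeply\n",
    "slope = np.gradient(loss, log_lr)\n",
    "suggested_lr = lrs[np.argmin(slope)]\n",
    "print('suggested learning rate: %.1e' % suggested_lr)\n",
    "```\n",
    "With real data, the loss curve is noisy, which is why the plot produced by `lr_find` is usually inspected by eye rather than trusted blindly."
   ]
  },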
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "simulating training for different learning rates... this may take a few moments...\n",
      "Train for 376 steps\n",
      "Epoch 1/2\n",
      "376/376 [==============================] - 73s 194ms/step - loss: 1.0788 - accuracy: 0.5191\n",
      "Epoch 2/2\n",
      "115/376 [========>.....................] - ETA: 43s - loss: 1.9950 - accuracy: 0.2482\n",
      "\n",
      "done.\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEMCAYAAADJQLEhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxU5dn/8c+VfU8ICSGQhH1flbCoqKiouFIVF1yqVqVardraPtYuaref1ta2T12LPJa64a4gqKAFBReEsENADHtIIPu+z9y/P2bAmCYhgTlzZrner1dezJxz5sx1J2S+Oec+577FGINSSqngFWJ3AUoppeylQaCUUkFOg0AppYKcBoFSSgU5DQKllApyGgRKKRXkLAsCEXleRIpEZGsH6xNF5D0R2SQi20TkZqtqUUop1TErjwjmAzM6WX8nkGuMGQdMAx4XkQgL61FKKdWOMKt2bIxZKSL9O9sEiBcRAeKAMqDlWPtNSUkx/ft3tlullFJtrVu3rsQYk9reOsuCoAueBBYBBUA8cLUxxtnehiIyB5gDkJWVRU5OjteKVEqpQCAi+zpaZ2dn8fnARqAPMB54UkQS2tvQGDPXGJNtjMlOTW030JRSSh0nO4PgZuBt45IH7AGG21iPUkoFJTuDYD9wDoCIpAHDgN021qOUUkHJsj4CEVmA62qgFBHJBx4CwgGMMc8Cvwfmi8gWQID7jTElVtWjlFKqfVZeNTT7GOsLgPOsen+llFJdo3cWK6VUkNMgUEopP/Bx7mG+OVxtyb41CJRSyg/86OX1vLX+oCX71iBQSik/4DCGsBCxZN8aBEop5eOMMTichlANAqWUCk4OpwHQIFBKqWDVokGglFLBzWlcQaB9BEopFaT0iEAppYKcw6FBoJRSQc2hp4aUUiq4fXvVkDUf2RoESinl477tI7Bm/xoESinl477tI9AjAqWUCkraR6CUUkHO4XQCetWQUkoFLb2PQCmlglyL3keglFLBTYeYUEqpIKenhpRSKsjpMNRKKRXkNAiUUirIHQmCML2hTCmlgpPfDjEhIs+LSJGIbO1km2kislFEtonIp1bVopRS/uzbG8r874hgPjCjo5UikgQ8DVxqjBkFXGlhLUop5bccrhzwv8tHjTErgbJONrkWeNsYs9+9fZFVtSillD8L5CEmhgI9ROQTEVknIt+3sRallPJZVt9HEGbJXrv+3hOAc4Bo4EsRWW2M2dl2QxGZA8wByMrK8mqRSillt0C+fDQfWGqMqTXGlAArgXHtbWiMmWuMyTbGZKempnq1SKWUstu3l48GXhAsBKaKSJiIxACTge021qOUUj7pyKmhEPGzU0MisgCYBqSISD7wEBAOYIx51hizXUQ+BDYDTmCeMabDS02VUipYHT0iCPWzIDDGzO7CNn8G/mxVDUopFQgCuY9AKaVUF+gQE0opFeSOXj5qUR+BBoFSSvm4ozeUWdRHoEGglFI+zm+HmFBKKeUZR44IrLp8VINAKaV8XEsA31CmlFKqC5xOgwiEaBAopVRwanEay44GQINAKaV8nsNpLOsfAA0CpZTyeXpEoJRSQc7hNJYNLwEaBEop5fM0CJRSKsi1OI1lE9eDBoFSSvm8xhYH4RYNLwEaBEop5fP2lNTSr2eMZfvXIFBKKR/mdBp2HqpmeO8Ey95Dg0AppXzYwYp6apscDOsdb9l7aBAopZQPW7+/HIBRffSIQCmlgtLyHUWkxEUwuk+iZe+hQaCUUj7KGMNn35RwxpBUywacAw0CpZTyWYeqGiitbWJcZpKl76NBoJRSPiq3oAqAkRb2DwCEWbp3pZRSXVbf5OCVNftZsaOIhmYHuYWuIBiRrkGglFJB4a31+fx+cS4Aw3vHU9fk4Pun9CMu0tqPasv2LiLPAxcDRcaY0Z1sNxH4ErjGGPOmVfUopZSvO1hRD8BLt0zmtME92VVcy8CUWMvf18o+
gvnAjM42EJFQ4E/AMgvrUEopv1BQUU9WcgxTh6QgIgzuFWfp1UJHWBYExpiVQNkxNvsx8BZQZFUdSinlLwoq6umTFOX197Wtj0BE+gKXAWcBE+2qoz1NLU52HKqiZ1wk+WV1lNU2ERkewulDUgkP/W52Vjc0Ex8V3uV9G2MQC6ecU0r5jz0ltcRHhREVHkpcZBgFFQ1MHpjs9Trs7Cz+O3C/McZ5rA9GEZkDzAHIysryyJvXNbXwr8/34nQaLhiTTmZyNGv2lPFR7mHe2XCQ6oaW/3pNSlwE4zN70NDsoKSmkaLqRspqmxiWFk9Di4PE6HB6xUdRVd9MZHgI6/aVExUeSq/4SDKTYyisrKe4upGBKXFU1jczLjOR80b1prHZQa+EKNITo+idEIWIUFnXjMHgcLqCIzk2AnANRxsWEmLpJBVKKWvVNbXw8KJtvJ6Tf3RZVnIMByvqyUiK9no9Yoyxbuci/YHF7XUWi8ge4MinWQpQB8wxxrzb2T6zs7NNTk5Ot2tpdjgpr2vioYXbCA0RlmwppHXTQwScBkTge+P7Mm1YKgUVDQxMjaVvUjSFlQ28kXOAfaV1REeEkhQTTnpiNA3NDvaU1JKeGMX2wirqmx2kxkfS4jBk9++BMVBU3ciOQ1U4ndDkcJKWEElKXCSrd5fS0Oz8Tp3R4aHERYVRVtuEw/ltgQlRYfSIjaCwsgEMTB6YzJlDU6lqaMHhdLJhfwWFlQ0M6RV3NDzSEiKpbmhh68FKGpodDEmLp7axhYr6ZmIjw5gyMJnJA5KJjQgjLiqMQalxRIWHdvt7q5TqmtKaRvaW1nHHS+soqm7kuslZJMdGEBYSQm5hJZFhoTx4yUhS4iI9/t4iss4Yk93uOruCoM12893bHfOqoeMNgv9sP8wt/3a9LiYilLomB/fPGM5lJ/Xlg62FbMmv5MIx6YzNSKRXgnfO0ZXWNLLzcA1xkWEU1zRwsKKBvMPVFFY2MCAllvioMMJDQwgR4UB5HeV1zcRHhRETHspb6/Mpr2sGXCE2sk8CWckxbDpQSYvTSY+YCIqqG3Eaw5QBPYkIC+GbohqSosNJigmnvK6JnL3ltLQKm9AQIbtfDzJ6xBAdEcLYjCQq65qJCAuhqLqB2kYH8VFhDOsdT0xEKCPSE0iNiyQsVO9LVOpYFqzZzy/f2YIxrrML/7xhAhP6ee80kC1BICILgGm4/to/DDwEhAMYY55ts+18LA6CvKIaFqzZT3piFLeePpCy2qajp1v8kdNpqKxvPvoXfHRE9/+SL6ttYl9pLXVNDirqmtlysJKVO4upqGuitLaJxpZvj1ZCQ4To8FBqm1q+cyQVFxnG2IxEhvWOp8VhKK1tZPqINMZlJjGgZ6xXrnhQylcZYyirbeI/O4p4eNE2EqLCSYgO4yfTh3LBmHSv1mLbEYEVjjcIVPfUNrZwuKqBHjERlNU1kZ4YRUyE65RVSU3j0eDYVVzDtoOVbC+sJjRESIgO43BVI+A6nZXdP5lJA5KZ2D+ZmIhQwkNDGNwrzubWKWWdyrpmnlj+DZ/llbDzcDVHDrqH9Irj5Vsne+2MQ1saBMpyDqfBGEOICOv2l7OnpJb1+8pZs7eM3cW139n21EE9mT0pixmje//XVVhK+TOn0zDr2S/YcKCCkzKTGNwrjhaH4cZT+zO6b6KtF3loEChbFVc3sm5fGS1OQ355Pc+t3E1pbROxEaF876S+XDs5i1EWjrWulNUKKuqZu3I3H+Ue5mBFPY/NGstV2Zl2l/UdGgTKpzidhmW5h1i0qYAVO4qpb3Zw0dh0Hrl8DAnduCdDKTut3FnMuxsPcrC8njV7yxDg3JFpTBnYkxtP6e9z/WMaBMpnVdY1M++z3TyxPI+UuAieuX4CE/t7/4Yapbpj4caD3PPqRhKjw+mbFM30kWnMOjmDrJ4xdpfWIQ0C5fM2HajgzlfWk19ez3kj05g6JIVrJ2XppanKp9Q0tnDbv3P4cncpE/v34MVbJvvN
vTedBYEOQ618wrjMJJbeewY/e2MTH2w9xLLcw+wvrePXF4+0uzSlANck8tc99xX1zQ5+NG0Qd58zxG9C4Fg0CJTPiI0M45nrJ+BwGh5etI15n+3hzGGpnD4k1e7SVJB7c10+Dy7cSn2zgzvPGsTPzx9ud0kepUGgfE5oiPCLC4azencpt8zP4YoJfbnjzME+ff5VBZaSmkY27K9g2bZDbM6v5OvD1QzvHc+T154ckPfBaB+B8lklNY3c9kIOG/ZXkJUcw9J7zziuO6iV6o7//fgb/vc/O3EaCA8VRqYn0DMukkcuH0OaTTeDeYL2ESi/lBIXyTs/Oo0v8kq4dt5XPLdqN3efM8TuslQAO1hRz5MrvmHqkFTuPnswo/okBsUfHxoEyuedOjiFi8am8/ePd9I7McrnbtRRgaGkppHLnvocgN9dOor+Xpgi0ldoECi/8NgVY6msa+Z/3txMXGQYF3p5wC4VuCrrmqltauGHL66jvK6JF34wOahCADQIlJ+IjQxj3o3ZzH5uNXe+sp67zhrMfecNs7ss5aeMMazfX8Fb6/N55av9AMRGhPLnWeM4ZVBPm6vzPg0C5TeiwkN5+dbJ3Pf6Jp5ckcfM8X0D8goOZR1jDC+t3scLX+7jm6IaRGDSgGR6xkbw/VP6B2UIgAaB8jMxEWH8/nujWb6jiHmrdvPoFWPtLkn5kWW5h/nNwm2MzUjk4UtGcs6INDKT9bJkDQLld1LiIrkyO4PX1+bz03OH2ja+u/IvWw9Wcv9bmxmaFsfbd5yqw5e0ot8J5ZdunTqQZqeTF1fvs7sU5QfW7CnjunlfERsRxrzvT9QQaEO/G8ov9U+J5Zzhaby0eh8NzQ67y1E+yOk0vLepgAv/dxVX/fNL4iLDeHXOFL1DvR0aBMpv3Xr6AMrrmnn6k112l6J8TG1jC79bnMuPF2zgm6Jqfn3RCD6893TtD+iA9hEovzV5QDKXndSXf/znG84dkcaYDJ3lTEF+eR2znvmSQ1UNTB6QzN+vGU96YrTdZfk0PSJQfktE+O3MUSTFhPObhVv1FJGirqmFe1/dSHVDMy/fOpkFt03REOgCDQLl1xKiwnnksjFsPFDB3z7eaXc5yib1TQ7mf76HWc98yfr95Tw2axynDU7xuekifZUGgfJ7F4xJZ9aEDJ7/bA/F1Y12l6O8bMP+cmY9+wUPv5dLRV0Tf7t6PBeN1SFIukP7CFRAuGPaIN5cl89ra/dz19k6QmkwKKys549LtrN4cyFxkWHM+34200em2V2WX7LsiEBEnheRIhHZ2sH660Rks4hsEZEvRGScVbWowDcoNY6pg1N45av9tDicdpejvOCeVzeyeHMh10zMZPnPztQQOAFWnhqaD8zoZP0e4ExjzBjg98BcC2tRQeD6Kf0oqGxg+Y4iu0tRFtt0oII1e8r4zcUjefSKsfSK17vLT4RlQWCMWQmUdbL+C2NMufvpaiDDqlpUcJg+ohe94iN5de0Bu0tRFiqrbWLeZ3uIDAvhymz92PAEX+ksvgX4wO4ilH8LCw1h9qQslu8o4q5X1lPfpJeTBppfvbOFk3//Ee9tKmDWhAwSosLtLikg2N5ZLCJn4QqCqZ1sMweYA5CVleWlypQ/uvX0Aby4eh+LNxcyMDWOn5471O6SlIe8v6WQl7/azxUnZ3D9lCzGZiTZXVLA6NIRgYjcIyIJ4vJ/IrJeRM470TcXkbHAPGCmMaa0o+2MMXONMdnGmOzU1NQTfVsVwOKjwln9wDlcNCaduSt3UVhZb3dJygO+PlTN3Qs2MC4ziT9eNpqTsnoQqvcIeExXTw39wBhTBZwH9ABuAB49kTcWkSzgbeAGY4zeCaQ8JiIshAcuHI7TCXNX7ra7HHWCjDH8ZuFW4qPCmH/TRKLCA38yeW/rahAcid4LgReNMdtaLWv/BSILgC+BYSKSLyK3iMjtInK7e5MHgZ7A0yKyUURyjqN+pdqV0SOG6SN7sXBjAZX1zXaXo07Awo0F
rNlTxv/MGE6P2Ai7ywlIXe0jWCciy4ABwAMiEg90erG2MWb2MdbfCtzaxfdXqttumTqAj3IPc9/rG5l340S7y1HHoaHZwf97fzvjMpO4OjvT7nICVlePCG4BfgFMNMbUAeHAzZZVpZQHTOiXzM2nDeDTncXUNrbYXY46Dm+tz6eoupH7ZwzTcYMs1NUgOAX42hhTISLXA78GKq0rSynPmDY0lWaH4fO8ErtLUd3kcBqeW7mbcRmJnDIwOCeV95auBsEzQJ17GIj7gF3AC5ZVpZSHZPdPJikmnPc2F9pdiuqmD7ceYm9pHT88cxAiejRgpa4GQYsxxgAzgSeNMU8B8daVpZRnRISFcMnYPizbdkhHJvUjxhieXJHHwNRYzh/V2+5yAl5Xg6BaRB7AddnoEhEJwdVPoJTPu/m0/jichqdW5NldiuqipdsOsb2wijvOHKT3C3hBV4PgaqAR1/0Eh3CNC/Rny6pSyoMGpsZx6bg+vLUuX4ed8AMLNx7k7gUbGdIrjpnj+9pdTlDoUhC4P/xfBhJF5GKgwRijfQTKb1w9MZPqxhbe36J9Bb5s44EK7nt9EydlJfHG7acQEeYrw6EFtq4OMXEVsAa4ErgK+EpEZllZmFKeNGlAMv17xvBajo5M6qvqmlr48YL1pCVEMfeGbJJi9OYxb+lq3P4K1z0ENxpjvg9MAn5jXVlKeZaIcGV2Jmv2lLGnpNbuclQ7nvlkFwfK6vnb1eNJjNEuSG/qahCEGGNaz/ZR2o3XKuUTZk3IIDREeF2PCnxOfnkdc1fuZub4PkwakGx3OUGnqx/mH4rIUhG5SURuApYA71tXllKel5YQxVnDUnlzXb5OZ+ljHvlgByJw/4zhdpcSlLraWfxzXFNJjnV/zTXG3G9lYUpZ4arsTIqrG/lM7zT2GV/tLmXJ5kJuP3MQfZKi7S4nKHX59I4x5i1jzE/dX+9YWZRSVjljqGs+i5v+tZZV3xTbXI1yOA2/W5xLn8QofnjGILvLCVqdBoGIVItIVTtf1SJS5a0ilfKUqPBQxme6Zrb6+Rubcd0wr+zypw93sK2gil9cOILoCJ1nwC6dBoExJt4Yk9DOV7wxJsFbRSrlSc9eP4GrszM5VNXAloM6dqJdNudXMHflbmZPyuKSsel2lxPU9MofFXR6J0bxiwuGExoifJR72O5yglJji4MH3t5CcmwEv7xwuA4qZzMNAhWUesRGkN2vBx9sPaSnh2zwwhf72FZQxZ+uGEt8lN4zYDcNAhW0Zo7vS15RjZ4e8rKqhmae+iSPM4amcu7INLvLUWgQqCB20dh0IsNCeHNdvt2lBA1jDL9dlEtFXTP/c/4wu8tRbhoEKmglRodz/qjeLNxYoKOSeskb6/J5a30+d589mNF9E+0uR7lpEKigdsMp/aisb+b5z/fYXUrAK6pq4A+Lc5nUP5l7pw+1uxzVigaBCmoT+yczdXAKr609oJ3GFvvbxztpaHby6BVjdCJ6H6NBoILe+aN7s7+sjl3FOiqpVbYVVPJGTj7XTMpkYGqc3eWoNjQIVNA7e3gvAJbv0HsKrNDY4uCnr22iR2yEnhLyURoEKuj1TYpmeO94lu8oOvbGqtveWX+Qrw9X8+jlY0iO1clmfJFlQSAiz4tIkYhs7WC9iMg/RCRPRDaLyMlW1aLUsUwfkcbaveXkFugQWp72wdZDZCXHHD3yUr7HyiOC+cCMTtZfAAxxf80BnrGwFqU6dcvUASRFh/OXZV/bXUpA2XGois/zSrhgdG8dRsKHWRYExpiVQFknm8wEXjAuq4EkEdGRp5QtesRGcGV2Jp/uLKa0ptHucgLG3z7aSXxUGD88U4eY9mV29hH0BVrPGZjvXvZfRGSOiOSISE5xsY4hr6wxc3wfHE7D+1sP2V1KQKhqaGbFjmIuOylD+wZ8nF90Fhtj5hpjso0x2ampqXaXowLU8N7xDE2LY+GGg3aXEhBW7iymyeHkorG97S5FHYOdQXAQyGz1PMO9
TClbiAhXTsgkZ1856/Z1dlZTdcXneSXER4UxPrOH3aWoY7AzCBYB33dfPTQFqDTGFNpYj1JcNyWLHjHhPP/ZXrtL8WvvbjjIgjUHmDKwJ6F6F7HPC7NqxyKyAJgGpIhIPvAQEA5gjHkWeB+4EMgD6oCbrapFqa6KiQjj4rF9eGPdAWobW4iNtOxXJKAt3ebqZ/m5jjDqFyz7X26MmX2M9Qa406r3V+p4XTw2nRdX7+Pj7YeZOb7d6xfUMew8XM25I9MYmhZvdymqC/yis1gpb5rYP5m0hEgWb9YzlcejscXB3tI6hqbpmEL+QoNAqTZCQoSLxvTh06+Lqaxvtrscv5NXVIPDafRowI9oECjVjkvGpdPkcOrk9sfh052ue30mD+hpcyWqqzQIlGrH+MwkMnpE896mArtL8StvrsvnsQ+/ZnjveHonRtldjuoiDQKl2iEiXDy2D5/nlVBW22R3OX5jhXsE119dNMLmSlR3aBAo1YGLx6bT4jR8qENOdNnGAxVcPDad04foCAD+RINAqQ6M6pPAwJRYFm/W00NdsbekloMV9YzPTLK7FNVNGgRKdcB1eiid1btLKapusLscn/fsp7uICAvhknF97C5FdZMGgVKduHhcH5wGlurpoU7ll9fx5rp8Zk/MJC1BO4n9jQaBUp0YmhZP/54xLMs9jOtmeNWel1bvx4DOO+CnNAiUOoazh6ex6psS7nt9k92l+KQWh5N3Nxxk2tBU+iRF212OOg4aBEodw21nDKBPYhQLNxVQVKV9BW29u7GAQ1UNzJ6UZXcp6jhpECh1DOmJ0bx82xQcTsMb6/LtLsen7Cut5XfvbWNEegLnjNDJ6f2VBoFSXTAgJZZTB/VkwZr9OJ3aVwBQ29jC9L9+SlVDCz8+e7BOTu/HNAiU6qJrJ2eRX17PqrwSu0vxCR9vP0yzw3DnWYO4YLROR+nPNAiU6qLzRvYmJS6Ceat2212K7RxOw9yVu8lMjua+c4fp0YCf0yBQqosiwkK4/cxBrPqmhDV7gntO4093FrGtoIr7zh1GiE5F6fc0CJTqhuun9CMpJpznP9tjdym2+nDrIeIjw7hwTLrdpSgP0CBQqhuiwkOZdXIGH28/HLST1ry/pZB3NxRw7qg0IsL0IyQQ6E9RqW66YIxrVNJPvi6yuxSvyy2o4q5X1jMmI5EHLx5pdznKQzQIlOqmkzKTSImLZNm24Ju97OlP8oiNCOP5myaSFBNhdznKQzQIlOqmkBDh3JFpLNlSyOrdpXaX4zXVDc0s23aYKyZkkBgdbnc5yoM0CJQ6Dt8b7xpq+Y6X1tHU4rS5Gu/4ePthmhxOHWY6AGkQKHUcJg/syT9vmEB5XTPLdwTHKaIlmwvpkxjFSTrxTMCxNAhEZIaIfC0ieSLyi3bWZ4nIChHZICKbReRCK+tRypOmj0gjLSGSuSt309jisLscS1XWN7NyZwkXjknX+wYCkGVBICKhwFPABcBIYLaItL3M4NfA68aYk4BrgKetqkcpTwsNES47KYP1+ys447EV7Cmptbsky3yc6zotdLGeFgpIVh4RTALyjDG7jTFNwKvAzDbbGCDB/TgR0MlhlV+548xB/Pz8YZTXNfOvzwP3JrNluYfomxTNuIxEu0tRFrAyCPoCB1o9z3cva+1h4HoRyQfeB37c3o5EZI6I5IhITnFxsRW1KnVcEmPCufOswZw7Io0lmwtpaA7MU0QHK+oZmhanYwoFKLs7i2cD840xGcCFwIsi8l81GWPmGmOyjTHZqampXi9SqWO54ZR+lNY2MXdlYA5IV1bTRHJspN1lKItYGQQHgcxWzzPcy1q7BXgdwBjzJRAFpFhYk1KWmDKwJ+cM78ULX+6j2RFYl5MaYyipbSIlTm8gC1RWBsFaYIiIDBCRCFydwYvabLMfOAdAREbgCgI996P80uxJWZTUNLJiR2ANPVHT2EJTi5OeGgQBy7IgMMa0AHcBS4HtuK4O2iYivxORS92b3QfcJiKb
gAXATcYYnf5J+aVpw1LpFR/J/C/24gigWcxKa5oA6KmnhgJWmJU7N8a8j6sTuPWyB1s9zgVOs7IGpbwlLDSEG6b04/GPdvKHJbk8dMkou0vyiNJaVxAk6xFBwLK7s1ipgHLX2YO5dFwfXv5qPyU1jXaX4xGl7nak6BFBwNIgUMqDRIR7pg+h2eHk31/stbscjzhc7QqC1HgNgkClQaCUhw1KjeP8kb159tNdATFnwYGyOiLCQuilQRCwNAiUssCfrhhL36RonlieZ3cpJ2x/aR1ZyTE6xlAA0yBQygKJMeFcP6Uf6/aVs+ob/74iel+ZKwhU4NIgUMoi103ux+Becfzqna1+ezmpMYYDGgQBT4NAKYtER4Tyk+lD2V9W57c3ma3fX0FNYwtj+upgc4FMg0ApC503Ko30xCjm++EVRMYY5n+xl4iwEM4blWZ3OcpCGgRKWSg8NITrp/Tjs7wSFm48iNOPThEt2lTAe5sKuP2MgcRH6RzFgUyDQCmLXT3RNfbiPa9u5JU1+22upmvqmlr4y7KvGZmewL3Th9pdjrKYBoFSFkuJi+TBi12T8y3e7B9zLz376W4OlNXz4CUj9bLRIGDpWENKKZcfTB1AeV0TT63Io6Cinj5J0XaX1Kklmws4bXBPpgzsaXcpygv0iEApLzlyiujlr/bZXEnn9pXWsqu4lnNHaAdxsNAgUMpLMnrEcObQVN5Z79udxhsPVAAwWY8GgoYGgVJeNHN8XwoqG1jpw3cb5xZUEREawuBecXaXorxEg0ApL5oxujf9esbwhyXbffaoILewiqG94wgP1Y+HYKE/aaW8KCrcdbdxXlENX+4utbucdn1zuIahafF2l6G8SINAKS+bMbo3STHhvPKV791T0NTi5HB1A5k9dGyhYKJBoJSXRYWHcsXJGSzddug7s5g5nIZnP91FuXtqSDscqmzAGOjbw7cvb1WepUGglA2uys6kxWn4YEvh0WUbD1Tw6Ac7uPWFHNvqyi+vAyBDgyCoaBAoZYOhaXEM7hXHq2sPUN/kAOBwVQMA6/aVU9XQbEtd+RX1AGQk6amhYKJBoJQNRIR7zhlCbmEVT3/imsWssLLh6PpFG+0ZimJXUQ1hIULvxChb3l/ZQ4NAKZtcMlX+pfkAAA2OSURBVK4PUwen8MTyPLYXVlFYUU9UeAjjMpP445Lt7Cmp9XpNn+4sJrt/DyLC9KMhmOhPWykbXXFyBgA/fHEdhZUN9EmM5p/XT6DF6eSBtzdT09hyQvsvqKgnr6i6S9vmFdWw41A104b1OqH3VP7H0iAQkRki8rWI5InILzrY5ioRyRWRbSLyipX1KOVrZo7vw73Th7C/rI4lWwpJT4qid2IU54/qzerdZTy4cOsx99HicLY7Feb2wiou+scqrnjmS2qPESj1TQ4eeHsz8ZFhR8NJBQ/LgkBEQoGngAuAkcBsERnZZpshwAPAacaYUcC9VtWjlC8SEW46tT/R4aEAXHaS60P4kcvHMG1YKgs3FnCgrK7Tffzo5fVc8L8rKattYsWOIu58eT1bD1Zy7XOrEREq65tZcIx5EN7ekM/aveX8/nujSY2P9EzjlN+w8ohgEpBnjNltjGkCXgVmttnmNuApY0w5gDHGPyd2VeoEJMVEsOJn03j+pmxmTXAFQXxUOI9ePhaH03D6YytYvuNwu69dvbuUZbmH2Xm4hnMe/4Sb569lyZZCLn7iM8JDQ3j7jlOZPCCZeav20NTiPPq6hmYHX+4qZcXXrl+5PcW1RIWHMHN8H+sbrHyOlUHQFzjQ6nm+e1lrQ4GhIvK5iKwWkRnt7UhE5ohIjojkFBf77mBdSh2v3olRnD087b+W/cQ9O9jjy3ZizHdP/9Q3ObjthRyykmN4bc4UesREHF0XHxnGX68aT/+UWG6fNohDVQ3MXbkLp9Pw1e5SpjzyH2Y/t5qb/7WWdzbkk19eT0aPGER0EppgZPfENGHAEGAakAGsFJExxpiK1hsZY+YCcwGy
s7N9c6QupSxwz/QhpMRH8Kt3trJ2bzmTBiQfXffFrhKqG1p4+rqTmTywJ//+wSTueXUDv5s5mtF9E49ud+aQVAalxvKXZTtZsOYAByvq6RkbwV+vGsezn+7ibx99Q0xEKH19fLIcZR0rjwgOApmtnme4l7WWDywyxjQbY/YAO3EFg1LK7fKTMkiMDufOV9bzy3e28OFW193Iy3cUERsRejQcMpNjePtHp30nBABCQoT5N0/i/hnDOei+YWzG6N5cfnIGPztvGPvL6thxqNrnZ01T1rEyCNYCQ0RkgIhEANcAi9ps8y6uowFEJAXXqaLdFtaklN+Jjgjl71ePp7i6kVe+2s/tL62nxeFk+Y4ipg5JITIs9Jj7yEyO4Y5pg1jzy3OYOjiFm08bAMC5I9O40t0vofMPBC/LgsAY0wLcBSwFtgOvG2O2icjvRORS92ZLgVIRyQVWAD83xvjm2LxK2eis4b24bnLW0edPLM+jsLKBc4Z3bzrJXglRvHTr5KMf+iLCY7PG8vaPTv3O/lVwkbYdUL4uOzvb5OTYNyiXUnZpdjg5VNnAjc+vYbf7ruM1vzqHXvE6HIQ6NhFZZ4zJbm+d3lmslJ8IDw0hMzmGn5zrupJobEaihoDyCLuvGlJKddNFY9JZuu0QZ+lQEMpDNAiU8jMhIcKT155sdxkqgOipIaWUCnIaBEopFeQ0CJRSKshpECilVJDTIFBKqSCnQaCUUkFOg0AppYKcBoFSSgU5vxtrSESKgX3up4lAZTuPU4ASD7xd632eyLYdrWtvedtlHbXRl9vb0frutrft8yOPPdXejmo6nu26+jM+nvaC93/Gnmpve8v88f90ILS3nzEmtd01xhi//QLmdvA4x9P7P5FtO1rX3vK2yzppo8+2t6ttO1Z7O2qzp9rbnTYfT3u72z5f+hl7qr3dbKPP/p8OxPa2/vL3U0PvdfDYiv2fyLYdrWtvedtlHbXRl9vb0frutrftczvbfDztbW95sLW3vWX++H86ENt7lN+dGuoKEckxHQy3Goi0vYEv2Nqs7fUufz8i6MhcuwvwMm1v4Au2Nmt7vSggjwiUUkp1XaAeESillOoiDQKllApyGgRKKRXkNAiUUirIBV0QiMjpIvKsiMwTkS/srsdqIhIiIn8UkSdE5Ea767GaiEwTkVXun/E0u+vxBhGJFZEcEbnY7lq8QURGuH++b4rIHXbXYzUR+Z6IPCcir4nIeVa8h18FgYg8LyJFIrK1zfIZIvK1iOSJyC8624cxZpUx5nZgMfBvK+s9UZ5oLzATyACagXyravUED7XXADVAFMHRXoD7gdetqdKzPPQ7vN39O3wVcJqV9Z4oD7X3XWPMbcDtwNWW1OlPl4+KyBm4fslfMMaMdi8LBXYC5+L6xV8LzAZCgUfa7OIHxpgi9+teB24xxlR7qfxu80R73V/lxph/isibxphZ3qq/uzzU3hJjjFNE0oC/GmOu81b93eWh9o4DeuIKvhJjzGLvVH98PPU7LCKXAncALxpjXvFW/d3l4c+sx4GXjTHrPV1nmKd3aCVjzEoR6d9m8SQgzxizG0BEXgVmGmMeAdo9VBaRLKDSl0MAPNNeEckHmtxPHdZVe+I89fN1KwcirajTUzz0850GxAIjgXoRed8Y47Sy7hPhqZ+xMWYRsEhElgA+GwQe+hkL8CjwgRUhAH4WBB3oCxxo9TwfmHyM19wC/MuyiqzV3fa+DTwhIqcDK60szCLdaq+IXA6cDyQBT1pbmiW61V5jzK8AROQm3EdDllZnje7+jKcBl+MK+vctrcwa3f0d/jEwHUgUkcHGmGc9XVAgBEG3GWMesrsGbzHG1OEKvqBgjHkbV/gFFWPMfLtr8BZjzCfAJzaX4TXGmH8A/7DyPfyqs7gDB4HMVs8z3MsClbZX2xtogq3NPtfeQAiCtcAQERkgIhHANcAim2uykrZX2xtogq3NPtdevwoCEVkAfAkME5F8EbnFGNMC3AUsBbYDrxtj
ttlZp6doe7W9BFB7Ifja7C/t9avLR5VSSnmeXx0RKKWU8jwNAqWUCnIaBEopFeQ0CJRSKshpECilVJDTIFBKqSCnQaAsJyI1XniPS7s4ZLMn33OaiJx6HK87SUT+z/34JhHxiTGRRKR/2+GS29kmVUQ+9FZNyjs0CJTfcA/f2y5jzCJjzKMWvGdn43FNA7odBMAvsXjsGKsYY4qBQhHx6XkAVPdoECivEpGfi8haEdksIr9ttfxdEVknIttEZE6r5TUi8riIbAJOEZG9IvJbEVkvIltEZLh7u6N/WYvIfBH5h4h8ISK7RWSWe3mIiDwtIjtE5CMRef/IujY1fiIifxeRHOAeEblERL4SkQ0i8rGIpLmHFr4d+ImIbBTXzHepIvKWu31r2/uwFJF4YKwxZlM76/qLyHL39+Y/7uHSEZFBIrLa3d4/tHeEJa5ZypaIyCYR2SoiV7uXT3R/HzaJyBoRiXe/zyr393B9e0c1IhIqIn9u9bP6YavV7wI+O8+DOg7GGP3SL0u/gBr3v+cBcwHB9UfIYuAM97pk97/RwFagp/u5Aa5qta+9wI/dj38EzHM/vgl40v14PvCG+z1G4hr7HWAWrmGLQ4DeuOYsmNVOvZ8AT7d63oNv78K/FXjc/fhh4GettnsFmOp+nAVsb2ffZwFvtXreuu73gBvdj38AvOt+vBiY7X58+5HvZ5v9XgE81+p5IhAB7AYmupcl4BpxOAaIci8bAuS4H/cHtrofzwF+7X4cCeQAA9zP+wJb7P5/pV+e+wrKYaiVbc5zf21wP4/D9UG0ErhbRC5zL890Ly/FNZnOW232c2SY6XW4xqVvz7vGNTZ/rrhmKwOYCrzhXn5IRFZ0UutrrR5nAK+JSDquD9c9HbxmOjBSRI48TxCROGNM67/g04HiDl5/Sqv2vAg81mr599yPXwH+0s5rtwCPi8ifgMXGmFUiMgYoNMasBTDGVIHr6AF4UkTG4/r+Dm1nf+cBY1sdMSXi+pnsAYqAPh20QfkhDQLlTQI8Yoz553cWuiYamQ6cYoypE5FPcE29CNBgjGk7s1qj+18HHf8fbmz1WDrYpjO1rR4/gWvay0XuWh/u4DUhwBRjTEMn+63n27Z5jDFmp4icDFwI/EFE/gO808HmPwEO45rmMgRor17BdeS1tJ11UbjaoQKE9hEob1oK/EBE4gBEpK+I9ML112a5OwSGA1Msev/PgSvcfQVpuDp7uyKRb8eLv7HV8mogvtXzZbhmkwLA/Rd3W9uBwR28zxe4hiQG1zn4Ve7Hq3Gd+qHV+u8QkT5AnTHmJeDPwMnA10C6iEx0bxPv7vxOxHWk4ARuwDVXbltLgTtEJNz92qHuIwlwHUF0enWR8i8aBMprjDHLcJ3a+FJEtgBv4vog/RAIE5HtuOZmXW1RCW/hmhYwF3gJWA9UduF1DwNviMg6oKTV8veAy450FgN3A9nuztVcXOfzv8MYswPXlIPxbdfhCpGbRWQzrg/oe9zL7wV+6l4+uIOaxwBrRGQj8BDwB2NME3A1rqlKNwEf4fpr/mngRvey4Xz36OeIebi+T+vdl5T+k2+Pvs4ClrTzGuWndBhqFVSOnLMXkZ7AGuA0Y8whL9fwE6DaGDOvi9vHAPXGGCMi1+DqOJ5paZGd17MS12Tr5XbVoDxL+whUsFksIkm4On1/7+0QcHsGuLIb20/A1bkrQAWuK4psISKpuPpLNAQCiB4RKKVUkNM+AqWUCnIaBEopFeQ0CJRSKshpECilVJDTIFBKqSD3/wGPjKpB0E20pAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learner.lr_find(show_plot=True, max_epochs=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 3: Train Model\n",
    "\n",
    "Train for four epochs using a [1cycle learning rate schedule](https://arxiv.org/pdf/1803.09820.pdf) with a maximum learning rate of `8e-5`."
   ]
  },
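  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under a 1cycle policy, the learning rate rises to its maximum over roughly the first half of training and then falls again. The simplified sketch below builds a symmetric triangular schedule for the run in this notebook, and it also reproduces the step counts in the training log (2257 training texts and 1502 validation texts with `batch_size=6`, rounded up). The starting rate of `max_lr / 10` and the exact triangle shape are assumptions: *ktrain*'s `fit_onecycle` and Smith's paper may differ in details such as a final annealing phase and momentum cycling.\n",
    "```python\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "n_train, n_val, batch_size, epochs = 2257, 1502, 6, 4\n",
    "max_lr = 8e-5          # as passed to learner.fit_onecycle(8e-5, 4)\n",
    "min_lr = max_lr / 10   # assumed starting and final rate\n",
    "\n",
    "steps_per_epoch = math.ceil(n_train / batch_size)  # 377, as in the training log\n",
    "val_steps = math.ceil(n_val / batch_size)          # 251, as in the training log\n",
    "total_steps = steps_per_epoch * epochs             # 1508\n",
    "half = total_steps // 2\n",
    "\n",
    "# Linear ramp up over the first half, linear ramp down over the second half\n",
    "up = np.linspace(min_lr, max_lr, half, endpoint=False)\n",
    "down = np.linspace(max_lr, min_lr, total_steps - half)\n",
    "schedule = np.concatenate([up, down])  # one learning rate per training step\n",
    "\n",
    "print(steps_per_epoch, val_steps, total_steps, schedule.max())\n",
    "```"
   ]
  },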
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "begin training using onecycle policy with max lr of 8e-05...\n",
      "Train for 377 steps, validate for 251 steps\n",
      "Epoch 1/4\n",
      "377/377 [==============================] - 89s 236ms/step - loss: 0.5214 - accuracy: 0.8285 - val_loss: 0.2847 - val_accuracy: 0.9081\n",
      "Epoch 2/4\n",
      "377/377 [==============================] - 80s 213ms/step - loss: 0.1524 - accuracy: 0.9513 - val_loss: 0.5775 - val_accuracy: 0.8309\n",
      "Epoch 3/4\n",
      "377/377 [==============================] - 81s 215ms/step - loss: 0.1066 - accuracy: 0.9739 - val_loss: 0.2469 - val_accuracy: 0.9387\n",
      "Epoch 4/4\n",
      "377/377 [==============================] - 81s 215ms/step - loss: 0.0318 - accuracy: 0.9907 - val_loss: 0.1645 - val_accuracy: 0.9561\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f6e7c4d2f28>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.fit_onecycle(8e-5, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 4: Evaluate/Inspect Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                        precision    recall  f1-score   support\n",
      "\n",
      "           alt.atheism       0.92      0.93      0.93       319\n",
      "         comp.graphics       0.97      0.97      0.97       389\n",
      "               sci.med       0.97      0.95      0.96       396\n",
      "soc.religion.christian       0.96      0.96      0.96       398\n",
      "\n",
      "              accuracy                           0.96      1502\n",
      "             macro avg       0.95      0.96      0.95      1502\n",
      "          weighted avg       0.96      0.96      0.96      1502\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[298,   2,   8,  11],\n",
       "       [  7, 378,   3,   1],\n",
       "       [  5,   8, 378,   5],\n",
       "       [ 15,   0,   1, 382]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.validate(class_names=t.get_classes())"
   ]
  },
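  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classification report can be recovered from the confusion matrix returned by `learner.validate`. Rows are true classes and columns are predicted classes (scikit-learn's convention; the row sums match the `support` column). The sketch below hard-codes the matrix printed above and recomputes per-class precision and recall plus overall accuracy with NumPy.\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "class_names = ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n",
    "# Confusion matrix from learner.validate(): rows = true class, columns = predicted class\n",
    "cm = np.array([[298,   2,   8,  11],\n",
    "               [  7, 378,   3,   1],\n",
    "               [  5,   8, 378,   5],\n",
    "               [ 15,   0,   1, 382]])\n",
    "\n",
    "correct = np.diag(cm)\n",
    "support = cm.sum(axis=1)               # true examples per class (row sums)\n",
    "precision = correct / cm.sum(axis=0)   # correct / all predictions of that class\n",
    "recall = correct / support             # correct / all true members of that class\n",
    "accuracy = correct.sum() / cm.sum()\n",
    "\n",
    "for name, p, r, n in zip(class_names, precision, recall, support):\n",
    "    print('%-24s precision=%.2f recall=%.2f support=%d' % (name, p, r, n))\n",
    "print('accuracy=%.2f' % accuracy)\n",
    "```\n",
    "Rounded to two decimals, these values agree with the report printed by `learner.validate`."
   ]
  },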
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------\n",
      "id:521 | loss:7.12 | true:sci.med | pred:comp.graphics)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# show the validation example with the highest loss (the one we got most wrong)\n",
    "learner.view_top_losses(n=1, preproc=t)"
   ]
  },
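  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assuming the reported loss is the categorical cross-entropy of that example (the usual loss for a softmax classifier), a loss of 7.12 means the model put only about exp(-7.12), roughly 0.08%, of its probability on the true class. The sketch below does that arithmetic and shows why ranking examples by loss surfaces the most confident mistakes. Apart from the value derived from 7.12, the probabilities in `p` are hypothetical.\n",
    "```python\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "# Loss reported by view_top_losses for the worst validation example\n",
    "loss = 7.12\n",
    "# Cross-entropy loss is -ln(p_true), so invert it to get p_true\n",
    "p_true = math.exp(-loss)\n",
    "print('probability on the true class: %.5f' % p_true)\n",
    "\n",
    "# Hypothetical true-class probabilities for four examples (the third is the one above)\n",
    "p = np.array([0.95, 0.40, p_true, 0.88])\n",
    "losses = -np.log(p)\n",
    "# Sort by loss, largest first: a low true-class probability means a high loss\n",
    "worst_first = np.argsort(losses)[::-1]\n",
    "print(worst_first[0], losses[worst_first[0]])\n",
    "```"
   ]
  },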
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "From: jim.zisfein@factory.com (Jim Zisfein) \n",
      "Subject: Data of skull\n",
      "Distribution: world\n",
      "Organization: Invention Factory's BBS - New York City, NY - 212-274-8298v.32bis\n",
      "Reply-To: jim.zisfein@factory.com (Jim Zisfein) \n",
      "Lines: 11\n",
      "\n",
      "GT> From: gary@concave.cs.wits.ac.za (Gary Taylor)\n",
      "GT> Hi, We are trying to develop a image reconstruction simulation for the skull\n",
      "\n",
      "You could do high resolution CT (computed tomographic) scanning of\n",
      "the skull.  Many CT scanners have an algorithm to do 3-D\n",
      "reconstructions in any plane you want.  If you did reconstructions\n",
      "every 2 degrees or so in all planes, you could use the resultant\n",
      "images to create user-controlled animation.\n",
      "---\n",
      " . SLMR 2.1 . E-mail: jim.zisfein@factory.com (Jim Zisfein)\n",
      "                                                                                                                        \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# understandable mistake - this sci.med post talks a lot about computer graphics\n",
    "print(x_test[521])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 5: Make Predictions on New Data in Deployment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = ktrain.get_predictor(learner.model, preproc=t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'soc.religion.christian'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.predict('Jesus Christ is the central figure of Christianity.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <style>\n",
       "    table.eli5-weights tr:hover {\n",
       "        filter: brightness(85%);\n",
       "    }\n",
       "</style>\n",
       "\n",
       "\n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "        \n",
       "\n",
       "    \n",
       "\n",
       "        \n",
       "\n",
       "        \n",
       "    \n",
       "        \n",
       "        \n",
       "    \n",
       "        <p style=\"margin-bottom: 0.5em; margin-top: 0em\">\n",
       "            <b>\n",
       "    \n",
       "        y=soc.religion.christian\n",
       "    \n",
       "</b>\n",
       "\n",
       "    \n",
       "    (probability <b>0.998</b>, score <b>7.287</b>)\n",
       "\n",
       "top features\n",
       "        </p>\n",
       "    \n",
       "    <table class=\"eli5-weights\"\n",
       "           style=\"border-collapse: collapse; border: none; margin-top: 0em; table-layout: auto; margin-bottom: 2em;\">\n",
       "        <thead>\n",
       "        <tr style=\"border: none;\">\n",
       "            \n",
       "                <th style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\" title=\"Feature contribution already accounts for the feature value (for linear models, contribution = weight * feature value), and the sum of feature contributions is equal to the score or, for some classifiers, to the probability. Feature values are shown if &quot;show_feature_values&quot; is True.\">\n",
       "                    Contribution<sup>?</sup>\n",
       "                </th>\n",
       "            \n",
       "            <th style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">Feature</th>\n",
       "            \n",
       "        </tr>\n",
       "        </thead>\n",
       "        <tbody>\n",
       "        \n",
       "            <tr style=\"background-color: hsl(120, 100.00%, 80.00%); border: none;\">\n",
       "    <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
       "        +7.336\n",
       "    </td>\n",
       "    <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
       "        Highlighted in text (sum)\n",
       "    </td>\n",
       "    \n",
       "</tr>\n",
       "        \n",
       "        \n",
       "\n",
       "        \n",
       "        \n",
       "            <tr style=\"background-color: hsl(0, 100.00%, 99.40%); border: none;\">\n",
       "    <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
       "        -0.049\n",
       "    </td>\n",
       "    <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
       "        &lt;BIAS&gt;\n",
       "    </td>\n",
       "    \n",
       "</tr>\n",
       "        \n",
       "\n",
       "        </tbody>\n",
       "    </table>\n",
       "\n",
       "    \n",
       "\n",
       "\n",
       "\n",
       "    <p style=\"margin-bottom: 2.5em; margin-top:-0.5em;\">\n",
       "        <span style=\"background-color: hsl(120, 100.00%, 71.39%); opacity: 0.92\" title=\"1.671\">jesus</span><span style=\"opacity: 0.80\"> </span><span style=\"background-color: hsl(120, 100.00%, 67.21%); opacity: 0.95\" title=\"2.031\">christ</span><span style=\"opacity: 0.80\"> </span><span style=\"background-color: hsl(120, 100.00%, 90.67%); opacity: 0.83\" title=\"0.337\">is</span><span style=\"opacity: 0.80\"> </span><span style=\"background-color: hsl(120, 100.00%, 94.71%); opacity: 0.81\" title=\"0.150\">the</span><span style=\"opacity: 0.80\"> </span><span style=\"background-color: hsl(120, 100.00%, 92.54%); opacity: 0.82\" title=\"0.245\">central</span><span style=\"opacity: 0.80\"> </span><span style=\"background-color: hsl(0, 100.00%, 94.75%); opacity: 0.81\" title=\"-0.148\">figure</span><span style=\"opacity: 0.80\"> </span><span style=\"background-color: hsl(120, 100.00%, 88.56%); opacity: 0.83\" title=\"0.451\">of</span><span style=\"opacity: 0.80\"> </span><span style=\"background-color: hsl(120, 100.00%, 60.00%); opacity: 1.00\" title=\"2.698\">christianity</span><span style=\"opacity: 0.80\">.</span>\n",
       "    </p>\n",
       "\n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "    \n",
       "\n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.explain('Jesus Christ is the central figure of Christianity.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.save('/tmp/my_20newsgroup_predictor')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "reloaded_predictor = ktrain.load_predictor('/tmp/my_20newsgroup_predictor')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'soc.religion.christian'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reloaded_predictor.predict('Jesus Christ is the central figure of Christianity.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "array([8.9553175e-03, 3.1522836e-04, 3.8172584e-04, 9.9034774e-01],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reloaded_predictor.predict_proba('Jesus Christ is the central figure of Christianity.')"
   ]
  },
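  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`predict_proba` returns one probability per class. The outputs in this notebook are consistent with those probabilities following the order of the list returned by `get_classes()`, so the index of the largest probability gives the label that `predict` returned. The sketch below hard-codes the probabilities and class list shown in this notebook so it runs without a trained model.\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Outputs copied from the reloaded_predictor cells in this notebook\n",
    "classes = ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n",
    "probs = np.array([8.9553175e-03, 3.1522836e-04, 3.8172584e-04, 9.9034774e-01],\n",
    "                 dtype=np.float32)\n",
    "\n",
    "# The predicted label is the class with the highest probability\n",
    "predicted = classes[int(np.argmax(probs))]\n",
    "\n",
    "# Rank all classes by probability, most likely first\n",
    "ranked = sorted(zip(classes, probs.tolist()), key=lambda pair: pair[1], reverse=True)\n",
    "print(predicted)\n",
    "print(ranked)\n",
    "```"
   ]
  },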
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reloaded_predictor.get_classes()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
